#include <asnumpy/math/exponents_and_logarithms.hpp>
#include <asnumpy/utils/npu_array.hpp>
#include <asnumpy/utils/status_handler.hpp>

#include <acl/acl.h>
#include <aclnn/aclnn_base.h>
#include <aclnnop/aclnn_exp.h>
#include <aclnnop/aclnn_expm1.h>
#include <aclnnop/aclnn_exp2.h>
#include <aclnnop/aclnn_log.h>
#include <aclnnop/aclnn_log10.h>
#include <aclnnop/aclnn_log2.h>
#include <aclnnop/aclnn_log1p.h>
#include <aclnnop/aclnn_logaddexp.h>
#include <aclnnop/aclnn_logaddexp2.h>

#include <fmt/base.h>
#include <fmt/format.h>
#include <stdexcept>

NPUArray Exp(const NPUArray& x, py::dtype dtype) {
    auto shape = x.shape;
    auto result = NPUArray(shape, dtype);
    uint64_t workspaceSize = 0;
    aclOpExecutor* executor;
    auto error = aclnnExpGetWorkspaceSize(x.tensorPtr, result.tensorPtr, &workspaceSize, &executor);
    CheckGetWorkspaceSizeAclnnStatus(error);
    void* workspaceAddr = nullptr;
    if(workspaceSize > 0) {
        error = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
        CheckMallocAclnnStatus(error);
    }
    error = aclnnExp(workspaceAddr, workspaceSize, executor, nullptr);
    CheckAclnnStatus(error, "aclnnExp error");
    error = aclrtSynchronizeDevice();
    CheckSynchronizeDeviceAclnnStatus(error);
    return result;
}

NPUArray Expm1(const NPUArray& x, py::dtype dtype) {
    auto shape = x.shape;
    auto result = NPUArray(shape, dtype);
    uint64_t workspaceSize = 0;
    aclOpExecutor* executor;
    auto error = aclnnExpm1GetWorkspaceSize(x.tensorPtr, result.tensorPtr, &workspaceSize, &executor);
    CheckGetWorkspaceSizeAclnnStatus(error);
    void* workspaceAddr = nullptr;
    if(workspaceSize > 0) {
        error = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
        CheckMallocAclnnStatus(error);
    }
    error = aclnnExpm1(workspaceAddr, workspaceSize, executor, nullptr);
    CheckAclnnStatus(error, "aclnnExpm1 error");
    error = aclrtSynchronizeDevice();
    CheckSynchronizeDeviceAclnnStatus(error);
    return result;
}

NPUArray Exp2(const NPUArray& x, py::dtype dtype) {
    auto shape = x.shape;
    auto result = NPUArray(shape, dtype);
    uint64_t workspaceSize = 0;
    aclOpExecutor* executor;
    auto error = aclnnExp2GetWorkspaceSize(x.tensorPtr, result.tensorPtr, &workspaceSize, &executor);
    CheckGetWorkspaceSizeAclnnStatus(error);
    void* workspaceAddr = nullptr;
    if(workspaceSize > 0) {
        error = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
        CheckMallocAclnnStatus(error);
    }
    error = aclnnExp2(workspaceAddr, workspaceSize, executor, nullptr);
    CheckAclnnStatus(error, "aclnnExp2 error");
    error = aclrtSynchronizeDevice();
    CheckSynchronizeDeviceAclnnStatus(error);
    return result;
}

NPUArray Log(const NPUArray& x, py::dtype dtype) {
    auto shape = x.shape;
    auto result = NPUArray(shape, dtype);
    uint64_t workspaceSize = 0;
    aclOpExecutor* executor;
    auto error = aclnnLogGetWorkspaceSize(x.tensorPtr, result.tensorPtr, &workspaceSize, &executor);
    CheckGetWorkspaceSizeAclnnStatus(error);
    void* workspaceAddr = nullptr;
    if(workspaceSize > 0) {
        error = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
        CheckMallocAclnnStatus(error);
    }
    error = aclnnLog(workspaceAddr, workspaceSize, executor, nullptr);
    CheckAclnnStatus(error, "aclnnLog error");
    error = aclrtSynchronizeDevice();
    CheckSynchronizeDeviceAclnnStatus(error);
    return result;
}

NPUArray Log10(const NPUArray& x, py::dtype dtype) {
    auto shape = x.shape;
    auto result = NPUArray(shape, dtype);
    uint64_t workspaceSize = 0;
    aclOpExecutor* executor;
    auto error = aclnnLog10GetWorkspaceSize(x.tensorPtr, result.tensorPtr, &workspaceSize, &executor);
    CheckGetWorkspaceSizeAclnnStatus(error);
    void* workspaceAddr = nullptr;
    if(workspaceSize > 0) {
        error = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
        CheckMallocAclnnStatus(error);
    }
    error = aclnnLog10(workspaceAddr, workspaceSize, executor, nullptr);
    CheckAclnnStatus(error, "aclnnLog10 error");
    error = aclrtSynchronizeDevice();
    CheckSynchronizeDeviceAclnnStatus(error);
    return result;
}

NPUArray Log2(const NPUArray& x, py::dtype dtype) {
    auto shape = x.shape;
    auto result = NPUArray(shape, dtype);
    uint64_t workspaceSize = 0;
    aclOpExecutor* executor;
    auto error = aclnnLog2GetWorkspaceSize(x.tensorPtr, result.tensorPtr, &workspaceSize, &executor);
    CheckGetWorkspaceSizeAclnnStatus(error);
    void* workspaceAddr = nullptr;
    if(workspaceSize > 0) {
        error = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
        CheckMallocAclnnStatus(error);
    }
    error = aclnnLog2(workspaceAddr, workspaceSize, executor, nullptr);
    CheckAclnnStatus(error, "aclnnLog2 error");
    error = aclrtSynchronizeDevice();
    CheckSynchronizeDeviceAclnnStatus(error);
    return result;
}

NPUArray Log1p(const NPUArray& x, py::dtype dtype) {
    auto shape = x.shape;
    auto result = NPUArray(shape, dtype);
    uint64_t workspaceSize = 0;
    aclOpExecutor* executor;
    auto error = aclnnLog1pGetWorkspaceSize(x.tensorPtr, result.tensorPtr, &workspaceSize, &executor);
    CheckGetWorkspaceSizeAclnnStatus(error);
    void* workspaceAddr = nullptr;
    if(workspaceSize > 0) {
        error = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
        CheckMallocAclnnStatus(error);
    }
    error = aclnnLog1p(workspaceAddr, workspaceSize, executor, nullptr);
    CheckAclnnStatus(error, "aclnnLog1p error");
    error = aclrtSynchronizeDevice();
    CheckSynchronizeDeviceAclnnStatus(error);
    return result;
}

NPUArray Logaddexp(const NPUArray& x1, const NPUArray& x2, py::dtype dtype) {
    auto broadcast = GetBroadcastShape(x1, x2);
    auto result = NPUArray(broadcast, dtype);
    uint64_t workspaceSize = 0;
    aclOpExecutor* executor;
    auto error = aclnnLogAddExpGetWorkspaceSize(x1.tensorPtr, x2.tensorPtr, result.tensorPtr, &workspaceSize, &executor);
    CheckGetWorkspaceSizeAclnnStatus(error);
    void* workspaceAddr = nullptr;
    if(workspaceSize > 0) {
        error = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
        CheckMallocAclnnStatus(error);
    }
    error = aclnnLogAddExp(workspaceAddr, workspaceSize, executor, nullptr);
    CheckAclnnStatus(error, "aclnnLogAddExp error");
    error = aclrtSynchronizeDevice();
    CheckSynchronizeDeviceAclnnStatus(error);
    return result;
}

NPUArray Logaddexp2(const NPUArray& x1, const NPUArray& x2, py::dtype dtype) {
    auto broadcast = GetBroadcastShape(x1, x2);
    auto result = NPUArray(broadcast, dtype);
    uint64_t workspaceSize = 0;
    aclOpExecutor* executor;
    auto error = aclnnLogAddExp2GetWorkspaceSize(x1.tensorPtr, x2.tensorPtr, result.tensorPtr, &workspaceSize, &executor);
    CheckGetWorkspaceSizeAclnnStatus(error);
    void* workspaceAddr = nullptr;
    if(workspaceSize > 0) {
        error = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
        CheckMallocAclnnStatus(error);
    }
    error = aclnnLogAddExp2(workspaceAddr, workspaceSize, executor, nullptr);
    CheckAclnnStatus(error, "aclnnLogAddExp2 error");
    error = aclrtSynchronizeDevice();
    CheckSynchronizeDeviceAclnnStatus(error);
    return result;
}